import torch
from torch import nn
from torch.nn import AdaptiveAvgPool2d
from torchsummary import summary

class Residual(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut branch.

    Args:
        input_channels: Number of channels of the incoming feature map.
        num_channels: Number of channels produced by both 3x3 convolutions.
        use_1conv: When true, the shortcut goes through a 1x1 convolution
            (using ``strides``) instead of passing the input through as-is.
        strides: Stride of the first 3x3 convolution and of the optional
            1x1 shortcut convolution.
    """

    def __init__(self, input_channels, num_channels, use_1conv=False, strides=1):
        super().__init__()
        # Registration order fixes the ordering of parameters and
        # state_dict entries, so the submodules are created in this sequence.
        self.ReLU = nn.ReLU()
        self.conv1 = nn.Conv2d(
            input_channels, num_channels, kernel_size=3, stride=strides, padding=1
        )
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        # Optional 1x1 projection on the shortcut branch; None means the
        # input is added to the main branch unchanged.
        self.conv3 = (
            nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
            if use_1conv
            else None
        )

    def forward(self, x):
        """Return ReLU(main_branch(x) + shortcut(x))."""
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.ReLU(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Shortcut branch: projected by the 1x1 convolution when one exists.
        shortcut = x if self.conv3 is None else self.conv3(x)

        return self.ReLU(out + shortcut)

class ResNet18(nn.Module):
    """ResNet-18 style classifier assembled from a caller-supplied residual block.

    The network is a stem (7x7 conv, ReLU, batch norm, max pool), four stages
    of two residual blocks each (64, 128, 256, 512 channels), and a head made
    of global average pooling, flatten and a linear layer.

    Args:
        Residual: Callable building one residual block. It is invoked as
            ``Residual(in_ch, out_ch, use_1conv=..., strides=...)`` and must
            return an ``nn.Module``.
        in_channels: Number of channels of the input images. Defaults to 3.
        num_classes: Number of output logits. Defaults to 2 (mask / no mask).
    """

    def __init__(self, Residual, in_channels=3, num_classes=2):
        super(ResNet18, self).__init__()
        # Stem. Note the order used here is Conv -> ReLU -> BN; it is kept
        # as-is so layer indices inside the Sequential stay stable.
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            # BatchNorm2d takes the channel count produced by the preceding
            # convolution. With batch norm in place, the default PyTorch
            # weight initialisation is used instead of a manual Kaiming init.
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        # Each stage groups two residual blocks. The first stage keeps the
        # channel count and resolution of the stem output.
        self.b2 = nn.Sequential(
            Residual(64, 64, use_1conv=False, strides=1),
            Residual(64, 64, use_1conv=False, strides=1))
        # Later stages change the channel count and use stride 2 in their
        # first block, which therefore needs the 1x1 shortcut convolution.
        self.b3 = nn.Sequential(
            Residual(64, 128, use_1conv=True, strides=2),
            Residual(128, 128, use_1conv=False, strides=1))
        self.b4 = nn.Sequential(
            Residual(128, 256, use_1conv=True, strides=2),
            Residual(256, 256, use_1conv=False, strides=1))
        self.b5 = nn.Sequential(
            Residual(256, 512, use_1conv=True, strides=2),
            Residual(512, 512, use_1conv=False, strides=1))

        # Head: global average pooling -> flatten -> linear classifier.
        self.b6 = nn.Sequential(
            AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes))

    def forward(self, x):
        """Run ``x`` through the stem, the four stages and the head; return logits."""
        for stage in (self.b1, self.b2, self.b3, self.b4, self.b5, self.b6):
            x = stage(x)
        return x

if __name__ == '__main__':
    # Smoke test: build the network on the available device and print a
    # per-layer summary.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18(Residual).to(device)
    # The stem convolution is built with in_channels=3, so the dummy input
    # must have 3 channels; a 1-channel input raises a channel-mismatch error.
    # summary() prints the report itself, so its return value is not printed.
    summary(model, (3, 224, 224))
